/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.solr.ltr;

import java.io.IOException;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FloatDocValuesField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.ReaderUtil;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.Weight;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.solr.SolrTestCase;
import org.apache.solr.core.SolrResourceLoader;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.feature.ValueFeature;
import org.apache.solr.ltr.model.LTRScoringModel;
import org.apache.solr.ltr.model.ModelException;
import org.apache.solr.ltr.model.TestLinearModel;
import org.apache.solr.ltr.norm.IdentityNormalizer;
import org.apache.solr.ltr.norm.Normalizer;
import org.apache.solr.ltr.norm.NormalizerException;
import org.junit.Test;

public class TestLTRScoringQuery extends SolrTestCase {

  public static final SolrResourceLoader solrResourceLoader =
      new SolrResourceLoader(Path.of("").toAbsolutePath());

  private IndexSearcher getSearcher(IndexReader r) {
    final IndexSearcher searcher = newSearcher(r, false, false);
    return searcher;
  }

  private static List<Feature> makeFeatures(int[] featureIds) {
    final List<Feature> features = new ArrayList<>();
    for (final int i : featureIds) {
      Map<String, Object> params = new HashMap<String, Object>();
      params.put("value", i);
      final Feature f =
          Feature.getInstance(solrResourceLoader, ValueFeature.class.getName(), "f" + i, params);
      f.setIndex(i);
      features.add(f);
    }
    return features;
  }

  private static List<Feature> makeFilterFeatures(int[] featureIds) {
    final List<Feature> features = new ArrayList<>();
    for (final int i : featureIds) {
      Map<String, Object> params = new HashMap<String, Object>();
      params.put("value", i);
      final Feature f =
          Feature.getInstance(solrResourceLoader, ValueFeature.class.getName(), "f" + i, params);
      f.setIndex(i);
      features.add(f);
    }
    return features;
  }

  private LTRScoringQuery.ModelWeight performQuery(
      TopDocs hits, IndexSearcher searcher, int docid, LTRScoringQuery model)
      throws IOException, ModelException {
    final List<LeafReaderContext> leafContexts = searcher.getTopReaderContext().leaves();
    final int n = ReaderUtil.subIndex(hits.scoreDocs[0].doc, leafContexts);
    final LeafReaderContext context = leafContexts.get(n);
    final int deBasedDoc = hits.scoreDocs[0].doc - context.docBase;

    final Weight weight = searcher.createWeight(searcher.rewrite(model), ScoreMode.COMPLETE, 1);
    final Scorer scorer = weight.scorer(context);

    // rerank using the field final-score
    scorer.iterator().advance(deBasedDoc);
    scorer.score();

    // assertEquals(42.0f, score, 0.0001);
    // assertTrue(weight instanceof AssertingWeight);
    // (AssertingIndexSearcher)
    assertTrue(weight instanceof LTRScoringQuery.ModelWeight);
    final LTRScoringQuery.ModelWeight modelWeight = (LTRScoringQuery.ModelWeight) weight;
    return modelWeight;
  }

  @Test
  public void testLTRScoringQueryEquality() throws ModelException {
    final List<Feature> features = makeFeatures(new int[] {0, 1, 2});
    final List<Normalizer> norms =
        new ArrayList<Normalizer>(
            Collections.nCopies(features.size(), IdentityNormalizer.INSTANCE));
    final List<Feature> allFeatures = makeFeatures(new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
    final Map<String, Object> modelParams = TestLinearModel.makeFeatureWeights(features);

    final LTRScoringModel algorithm1 =
        TestLinearModel.createLinearModel(
            "testModelName", features, norms, "testStoreName", allFeatures, modelParams);

    final LTRScoringQuery m0 = new LTRScoringQuery(algorithm1);

    final HashMap<String, String[]> externalFeatureInfo = new HashMap<>();
    externalFeatureInfo.put("queryIntent", new String[] {"company"});
    externalFeatureInfo.put("user_query", new String[] {"abc"});
    final LTRScoringQuery m1 = new LTRScoringQuery(algorithm1, externalFeatureInfo, null);

    final HashMap<String, String[]> externalFeatureInfo2 = new HashMap<>();
    externalFeatureInfo2.put("user_query", new String[] {"abc"});
    externalFeatureInfo2.put("queryIntent", new String[] {"company"});
    int totalPoolThreads = 10, numThreadsPerRequest = 10;
    LTRThreadModule threadManager = new LTRThreadModule(totalPoolThreads, numThreadsPerRequest);
    final LTRScoringQuery m2 = new LTRScoringQuery(algorithm1, externalFeatureInfo2, threadManager);

    // Models with same algorithm and efis, just in different order should be the same
    assertEquals(m1, m2);
    assertEquals(m1.hashCode(), m2.hashCode());

    // Models with same algorithm, but different efi content should not match
    assertNotEquals(m1, m0);
    assertNotEquals(m1.hashCode(), m0.hashCode());

    final LTRScoringModel algorithm2 =
        TestLinearModel.createLinearModel(
            "testModelName2", features, norms, "testStoreName", allFeatures, modelParams);
    final LTRScoringQuery m3 = new LTRScoringQuery(algorithm2);

    assertNotEquals(m1, m3);
    assertNotEquals(m1.hashCode(), m3.hashCode());

    final LTRScoringModel algorithm3 =
        TestLinearModel.createLinearModel(
            "testModelName", features, norms, "testStoreName3", allFeatures, modelParams);
    final LTRScoringQuery m4 = new LTRScoringQuery(algorithm3);

    assertNotEquals(m1, m4);
    assertNotEquals(m1.hashCode(), m4.hashCode());
  }

  @Test
  public void testLTRScoringQuery() throws IOException, ModelException {
    final Directory dir = newDirectory();
    final RandomIndexWriter w = new RandomIndexWriter(random(), dir);

    Document doc = new Document();
    doc.add(newStringField("id", "0", Field.Store.YES));
    doc.add(newTextField("field", "wizard the the the the the oz", Field.Store.NO));
    doc.add(new FloatDocValuesField("final-score", 1.0f));

    w.addDocument(doc);
    doc = new Document();
    doc.add(newStringField("id", "1", Field.Store.YES));
    // 1 extra token, but wizard and oz are close;
    doc.add(newTextField("field", "wizard oz the the the the the the", Field.Store.NO));
    doc.add(new FloatDocValuesField("final-score", 2.0f));
    w.addDocument(doc);

    final IndexReader r = w.getReader();
    w.close();

    // Do ordinary BooleanQuery:
    final BooleanQuery.Builder bqBuilder = new BooleanQuery.Builder();
    bqBuilder.add(new TermQuery(new Term("field", "wizard")), BooleanClause.Occur.SHOULD);
    bqBuilder.add(new TermQuery(new Term("field", "oz")), BooleanClause.Occur.SHOULD);
    final IndexSearcher searcher = getSearcher(r);
    // first run the standard query
    final TopDocs hits = searcher.search(bqBuilder.build(), 10);
    assertEquals(2, hits.totalHits.value());
    assertEquals("0", searcher.storedFields().document(hits.scoreDocs[0].doc).get("id"));
    assertEquals("1", searcher.storedFields().document(hits.scoreDocs[1].doc).get("id"));

    List<Feature> features = makeFeatures(new int[] {0, 1, 2});
    final List<Feature> allFeatures = makeFeatures(new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
    List<Normalizer> norms =
        new ArrayList<Normalizer>(
            Collections.nCopies(features.size(), IdentityNormalizer.INSTANCE));
    LTRScoringModel ltrScoringModel =
        TestLinearModel.createLinearModel(
            "test",
            features,
            norms,
            "test",
            allFeatures,
            TestLinearModel.makeFeatureWeights(features));

    LTRScoringQuery.ModelWeight modelWeight =
        performQuery(hits, searcher, hits.scoreDocs[0].doc, new LTRScoringQuery(ltrScoringModel));
    assertEquals(3, modelWeight.getModelFeatureValuesNormalized().length);

    for (int i = 0; i < 3; i++) {
      assertEquals(i, modelWeight.getModelFeatureValuesNormalized()[i], 0.0001);
    }
    int[] posVals = new int[] {0, 1, 2};
    int pos = 0;
    for (LTRScoringQuery.FeatureInfo fInfo : modelWeight.getAllFeaturesInStore()) {
      if (fInfo == null) {
        continue;
      }
      assertEquals(posVals[pos], fInfo.getValue(), 0.0001);
      assertEquals("f" + posVals[pos], fInfo.getName());
      pos++;
    }

    final int[] mixPositions = new int[] {8, 2, 4, 9, 0};
    features = makeFeatures(mixPositions);
    norms =
        new ArrayList<Normalizer>(
            Collections.nCopies(features.size(), IdentityNormalizer.INSTANCE));
    ltrScoringModel =
        TestLinearModel.createLinearModel(
            "test",
            features,
            norms,
            "test",
            allFeatures,
            TestLinearModel.makeFeatureWeights(features));

    modelWeight =
        performQuery(hits, searcher, hits.scoreDocs[0].doc, new LTRScoringQuery(ltrScoringModel));
    assertEquals(mixPositions.length, modelWeight.getModelFeatureWeights().length);

    for (int i = 0; i < mixPositions.length; i++) {
      assertEquals(mixPositions[i], modelWeight.getModelFeatureValuesNormalized()[i], 0.0001);
    }

    final ModelException expectedModelException =
        new ModelException("no features declared for model test");
    final int[] noPositions = new int[] {};
    features = makeFeatures(noPositions);
    norms =
        new ArrayList<Normalizer>(
            Collections.nCopies(features.size(), IdentityNormalizer.INSTANCE));
    try {
      ltrScoringModel =
          TestLinearModel.createLinearModel(
              "test",
              features,
              norms,
              "test",
              allFeatures,
              TestLinearModel.makeFeatureWeights(features));
      fail("unexpectedly got here instead of catching " + expectedModelException);
      modelWeight =
          performQuery(hits, searcher, hits.scoreDocs[0].doc, new LTRScoringQuery(ltrScoringModel));
      assertEquals(0, modelWeight.getModelFeatureWeights().length);
    } catch (ModelException actualModelException) {
      assertEquals(expectedModelException.toString(), actualModelException.toString());
    }

    // test normalizers
    features = makeFilterFeatures(mixPositions);
    final Normalizer norm =
        new Normalizer() {

          @Override
          public float normalize(float value) {
            return 42.42f;
          }

          @Override
          public LinkedHashMap<String, Object> paramsToMap() {
            return null;
          }

          @Override
          protected void validate() throws NormalizerException {}
        };
    norms = new ArrayList<Normalizer>(Collections.nCopies(features.size(), norm));
    final LTRScoringModel normMeta =
        TestLinearModel.createLinearModel(
            "test",
            features,
            norms,
            "test",
            allFeatures,
            TestLinearModel.makeFeatureWeights(features));

    modelWeight =
        performQuery(hits, searcher, hits.scoreDocs[0].doc, new LTRScoringQuery(normMeta));
    normMeta.normalizeFeaturesInPlace(modelWeight.getModelFeatureValuesNormalized());
    assertEquals(mixPositions.length, modelWeight.getModelFeatureWeights().length);
    for (int i = 0; i < mixPositions.length; i++) {
      assertEquals(42.42f, modelWeight.getModelFeatureValuesNormalized()[i], 0.0001);
    }
    r.close();
    dir.close();
  }
}
